## IMPORT

from pyspark.context import SparkContext, SparkConf
from pyspark.sql import SparkSession, Window, types
from pyspark.sql.types import StructType, IntegerType
from pyspark.sql.context import SQLContext
import pyspark.sql.functions as f

from math import floor, ceil
import sys, json
import pandas as pd

from splink.spark.spark_linker import SparkLinker
from splink.spark.spark_comparison_library import jaro_winkler_at_thresholds, levenshtein_at_thresholds, exact_match
from splink.charts import save_offline_chart

## SPARK SESSION 

# Reuse the SparkContext that spark-submit started (or start one with a default
# SparkConf) and wrap it in a SparkSession for DataFrame / SQL work.
sc = SparkContext.getOrCreate(conf=SparkConf())
spark = SparkSession(sparkContext=sc)

# `spark.sparkContext` is this same `sc`; only WARN and above are logged.
sc.setLogLevel('WARN')

# NOTE(review): this checkpoint directory lives under a `splink-demo32` project
# path rather than the `affl` tree used everywhere else — confirm it is intended.
sc.setCheckpointDir("/scratch/gpfs/ericmm/splink-demo32/temp_graphframes/")

# Expose the Java Jaro-Winkler similarity class as a SQL function named
# `jaro_winkler` returning DOUBLE. The class must be on the Spark classpath.
spark.udf.registerJavaFunction(
    name='jaro_winkler',
    javaClassName='uk.gov.moj.dash.linkage.JaroWinklerSimilarity',
    returnType=types.DoubleType(),
)

## LOAD DATA FROM SLURM PARAMS

# Two positional arguments are expected: the index of the state in STATES
# (presumably the SLURM array task id — confirm against the sbatch script)
# and the vintage year.
STATES = (
    "AL", "AK", "AZ", "AR", "CA", "CO", "CT", "DE", "DC", "FL", "GA", "HI",
    "ID", "IL", "IN", "IA", "KS", "KY", "LA", "ME", "MD", "MA", "MI", "MN", "MS", "MO",
    "MT", "NE", "NV", "NH", "NJ", "NM", "NY", "NC", "ND", "OH", "OK", "OR", "PA", "RI",
    "SC", "SD", "TN", "TX", "UT", "VT", "VA", "WA", "WV", "WI", "WY",
)

if len(sys.argv) < 3:
    sys.exit(f"usage: {sys.argv[0]} STATE_INDEX YEAR")

_state_index = int(sys.argv[1])
# Reject out-of-range indices explicitly. Python would otherwise wrap negative
# values silently (e.g. -1 -> "wy"), and the job would process the wrong state.
if not 0 <= _state_index < len(STATES):
    sys.exit(f"STATE_INDEX must be between 0 and {len(STATES) - 1}, got {_state_index}")

# Lower-cased because it is used in input paths and in the L2 `state` filter.
state = STATES[_state_index].lower()
year = int(sys.argv[2])

print(f"Working in {state} {year}...")

# CoreLogic records for this state/vintage that have an address.
cl = (
    spark.read.parquet(f"/scratch/gpfs/ericmm/affl/data/cl_in/{state}/vintage={year}")
    .filter("address IS NOT NULL")
    .select('uid', 'gender', 'zip', 'city', 'address', 'first_only', 'last_only', 'first_m', 'last')
)
# Count once and reuse the value: every count() is a full Spark job.
# Target ~1000 rows per partition, but never request 0 partitions. Spark
# rejects repartition(0), which would crash the job on an empty input.
n_cl = cl.count()
cl = cl.repartition(max(1, ceil(n_cl / 1000)))
print(f"There are {n_cl} CoreLogic observations...")

# Years before 2014 are matched against the 2014 L2 snapshot
# (presumably the earliest one available — confirm).
l2_year = max(year, 2014)
# `state` comes from the fixed STATES whitelist, so interpolating it into the
# SQL filter is safe. Column names are aligned with `cl` (uid, last).
vf = (
    spark.read.parquet(f"/scratch/gpfs/ericmm/affl/data/l2/year={l2_year}")
    .filter(f"address IS NOT NULL AND state = '{state}'")
    .select('emmid', 'gender', 'zip', 'city', 'address', 'first_only', 'last_only', 'first_m', 'last_sf')
    .withColumnRenamed('last_sf', 'last')
    .withColumnRenamed('emmid', 'uid')
)
n_vf = vf.count()
vf = vf.repartition(max(1, ceil(n_vf / 1000)))
print(f"There are {n_vf} L2 observations...")

## SETTINGS

# Extra predicate appended to every blocking rule: the two records' gender
# codes must differ by less than 1.001.
# NOTE(review): presumably gender is a numeric code where a difference of 2
# means conflicting genders — confirm. If either gender is NULL the predicate
# is NULL, so such records are never paired by any rule; confirm that is intended.
no_comp = " AND abs(l.gender - r.gender) < 1.001"


def _blocking_rule(*conditions):
    """Join SQL conditions with AND and append the gender predicate `no_comp`."""
    return " AND ".join(conditions) + no_comp


# Same zip and exact first (br1) or last (br2) name.
br1 = _blocking_rule("l.zip = r.zip", "l.first_only = r.first_only")
br2 = _blocking_rule("l.zip = r.zip", "l.last_only = r.last_only")
# Same city, one name exact and the other sharing its first two characters.
br3 = _blocking_rule(
    "l.last_only = r.last_only",
    "substring(l.first_only,1,2) = substring(r.first_only,1,2)",
    "l.city = r.city",
)
br4 = _blocking_rule(
    "l.first_only = r.first_only",
    "substring(l.last_only,1,2) = substring(r.last_only,1,2)",
    "l.city = r.city",
)
# Same address and city, with a matching first (br5) or last (br6) initial.
br5 = _blocking_rule(
    "l.address = r.address",
    "substring(l.first_only,1,1) = substring(r.first_only,1,1)",
    "l.city = r.city",
)
br6 = _blocking_rule(
    "l.address = r.address",
    "substring(l.last_only,1,1) = substring(r.last_only,1,1)",
    "l.city = r.city",
)

# (rule, salting_partitions): the name-based rules are salted 4 ways, the
# address rules are not salted.
_blocking_with_salting = [(br1, 4), (br2, 4), (br3, 4), (br4, 4), (br5, 1), (br6, 1)]

# Jaro-Winkler comparisons on first name, last name and address, each with an
# exact-match level plus thresholds 0.94 / 0.9 / 0.86. Term-frequency
# adjustment is on for the names only.
# NOTE(review): the original marked include_exact_match_level on the name
# comparisons as a "switch" — presumably a setting the author toggles.
_name_comparisons = [
    jaro_winkler_at_thresholds(
        column,
        [0.94, 0.9, 0.86],
        include_exact_match_level=True,
        term_frequency_adjustments=use_tf,
    )
    for column, use_tf in (("first_m", True), ("last", True), ("address", False))
]

# City: thresholds 0.94 / 0.88, no exact-match level, term-frequency adjusted.
# NOTE(review): the original note said "0.9 if error" — presumably fall back
# to 0.9 if the 0.88 threshold causes a failure.
_city_comparison = jaro_winkler_at_thresholds(
    "city",
    [0.94, 0.88],
    include_exact_match_level=False,
    term_frequency_adjustments=True,
)

settings = {
    "link_type": "link_only",
    "blocking_rules_to_generate_predictions": [
        {"blocking_rule": rule, "salting_partitions": n_salt}
        for rule, n_salt in _blocking_with_salting
    ],
    "comparisons": _name_comparisons + [_city_comparison],
    "retain_matching_columns": True,
    "retain_intermediate_calculation_columns": False,
    "unique_id_column_name": "uid",
    "max_iterations": 40,
}

# Link-only linker between the CoreLogic (cl) and L2 (vf) tables.
linker = SparkLinker(
    input_table_or_tables=[cl, vf],
    settings_dict=settings,
    spark=spark,
)

# u-probabilities by random sampling, with the sample size tied to the L2 row count.
linker.estimate_u_using_random_sampling(target_rows=vf.count())


def _em_session(blocking_rule, deactivated):
    """Run one EM training session on `linker` with u-probabilities held fixed.

    Args:
        blocking_rule: SQL blocking rule that selects the training pairs.
        deactivated: output column names of the comparisons to switch off for
            this session. For each one, the comparison level whose
            comparison-vector value is 1 is passed as a level that reverses
            the blocking rule.
    """
    # NOTE(review): relies on private splink internals (`_settings_obj` and
    # its underscore-prefixed helpers); these may change between versions.
    reversing_levels = [
        linker._settings_obj
        ._get_comparison_by_output_column_name(name)
        ._get_comparison_level_by_comparison_vector_value(1)
        for name in deactivated
    ]
    linker.estimate_parameters_using_expectation_maximisation(
        blocking_rule,
        comparisons_to_deactivate=list(deactivated),
        comparison_levels_to_reverse_blocking_rule=reversing_levels,
        fix_u_probabilities=True,
    )


# Three sessions with different blocking rules, so that every comparison is
# active in at least one of them.
_em_session("l.zip = r.zip AND l.first_only = r.first_only", ["city", "first_m"])
_em_session("l.zip = r.zip AND l.last_only = r.last_only", ["city", "last"])
_em_session("l.last_only = r.last_only AND l.first_only = r.first_only", ["last", "first_m"])

# Save the trained model settings and a match-weights chart for review.
linker.save_settings_to_json(f"/scratch/gpfs/ericmm/affl/logs/{state}_{year}_cl_l2_addr.json", overwrite=True)
save_offline_chart(
    linker.match_weights_chart().spec,
    f"/scratch/gpfs/ericmm/affl/logs/{state}_{year}_cl_l2_addr.html",
    overwrite=True,
)

matches_path = f"/scratch/gpfs/ericmm/affl/data/cl_l2_matches/state={state}/year={year}/type=a/"

# Score pairs with threshold_match_probability=0.5. Then keep pairs where:
#   - both names agree at some level (gamma > 0), or
#   - one name is null (gamma = -1, presumably splink's null level) and the
#     other reaches level 4 (presumably the exact-match level for these
#     comparisons);
# and the address agrees at some level.
df = (
    linker.predict(threshold_match_probability=0.5)
    .as_spark_dataframe()
    .select("match_probability", "uid_l", "uid_r", "gamma_last", "gamma_first_m", "gamma_address")
    .filter("(gamma_last > 0 AND gamma_first_m > 0) OR (gamma_last = -1 AND gamma_first_m = 4) OR (gamma_last = 4 AND gamma_first_m = -1)")
    .filter("gamma_address > 0")
)

# Write first, then count the written parquet. Calling df.count() before the
# write would evaluate the whole prediction/filter pipeline a second time.
# NOTE: the reported count includes the gamma filters above, not only the
# 0.5 probability threshold.
df.coalesce(20).write.mode("overwrite").parquet(matches_path)
print(f"There are {spark.read.parquet(matches_path).count()} predicted matches with posterior >= 0.5")
